from typing import Optional
import torch
import torch.nn as nn


class NeuralNet(nn.Module):
    """Implements an MLP: a feature torso followed by a single linear output layer.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        hidden_dim: Width of every hidden layer of the default torso
            (ignored when ``torso`` is given).
        n_hidden_layers: Number of ``Linear -> GELU -> Dropout`` blocks in the
            default torso; must be >= 1 (ignored when ``torso`` is given).
        torso: Optional pre-built feature extractor, e.g. the ``torso`` of
            another instance. Its output width is read from ``torso.hidden_dim``
            when present, otherwise inferred from the last submodule exposing
            ``out_features``.
        dropout: Dropout probability used in the default torso (default 0.2).
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: Optional[int] = 64,
                 n_hidden_layers: Optional[int] = 2,
                 torso: Optional[nn.Module] = None,
                 dropout: float = 0.2):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_hidden_layers = n_hidden_layers

        if torso is None:
            assert (
                n_hidden_layers is not None and n_hidden_layers > 0
            ), "n_hidden_layers must be >= 1"
            self.torso = self._build_torso(input_dim, hidden_dim, n_hidden_layers, dropout)
        else:
            self.torso = torso
            self._check_torso_input_dim(torso, input_dim)

        self.last_layer = nn.Linear(self._torso_output_dim(self.torso), output_dim)

    @staticmethod
    def _build_torso(input_dim: int, hidden_dim: int, n_hidden_layers: int,
                     dropout: float) -> nn.Sequential:
        """Build ``n_hidden_layers`` blocks of Linear -> GELU -> Dropout."""
        layers = [nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)]
        for _ in range(n_hidden_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)])
        torso = nn.Sequential(*layers)
        # Record the output width on the torso so it can be shared with other models.
        torso.hidden_dim = hidden_dim
        return torso

    @staticmethod
    def _check_torso_input_dim(torso: nn.Module, input_dim: int) -> None:
        """Assert that the first submodule with ``in_features`` accepts ``input_dim``."""
        first = next((m for m in torso.modules() if hasattr(m, "in_features")), None)
        if first is not None:
            assert first.in_features == input_dim, "Input dim does not match torso input dim"

    @staticmethod
    def _torso_output_dim(torso: nn.Module) -> int:
        """Return the torso's output width: ``torso.hidden_dim`` or the last ``out_features``."""
        dim = getattr(torso, "hidden_dim", None)
        if dim is not None:
            return dim
        # Reversed pre-order traversal visits the last leaf of a Sequential first.
        for module in reversed(list(torso.modules())):
            if hasattr(module, "out_features"):
                return module.out_features
        raise AttributeError(
            "Cannot infer torso output dim; set `torso.hidden_dim` explicitly"
        )

    def forward(self, x, head_index=None):
        """Apply the torso then the output layer.

        ``head_index`` is ignored; presumably it is accepted so that callers can
        use this class interchangeably with ``MultiHeadedMLP.forward``.
        """
        x = self.torso(x)

        return self.last_layer(x)
    

class MultiHeadedMLP(nn.Module):
    """Implements an MLP with a shared torso and ``n_heads`` independent linear heads.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features of each head.
        n_heads: Number of linear heads on top of the shared torso.
        hidden_dim: Width of every hidden layer of the default torso
            (ignored when ``torso`` is given).
        n_hidden_layers: Number of ``Linear -> GELU -> Dropout`` blocks in the
            default torso; must be >= 1 (ignored when ``torso`` is given).
        delta: Per-head inclusion probability used by ``get_heads_to_include``
            (bootstrapping selection of heads in Thompson Sampling).
        torso: Optional pre-built feature extractor. Its output width is read
            from ``torso.hidden_dim`` when present, otherwise inferred from the
            last submodule exposing ``out_features``.
        dropout: Dropout probability used in the default torso (default 0.2).
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 n_heads: int,
                 hidden_dim: Optional[int] = 64,
                 n_hidden_layers: Optional[int] = 2,
                 delta: float = 0.3,  # parameter for bootstrapping selection of heads in Thompson Sampling
                 torso: Optional[nn.Module] = None,
                 dropout: float = 0.2):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_hidden_layers = n_hidden_layers
        self.n_heads = n_heads
        self.delta = delta

        if torso is None:
            assert (
                n_hidden_layers is not None and n_hidden_layers > 0
            ), "n_hidden_layers must be >= 1"
            self.torso = self._build_torso(input_dim, hidden_dim, n_hidden_layers, dropout)
        else:
            self.torso = torso
            self._check_torso_input_dim(torso, input_dim)

        torso_dim = self._torso_output_dim(self.torso)
        self.heads = nn.ModuleList([nn.Linear(torso_dim, output_dim) for _ in range(n_heads)])

    @staticmethod
    def _build_torso(input_dim: int, hidden_dim: int, n_hidden_layers: int,
                     dropout: float) -> nn.Sequential:
        """Build ``n_hidden_layers`` blocks of Linear -> GELU -> Dropout."""
        layers = [nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)]
        for _ in range(n_hidden_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)])
        torso = nn.Sequential(*layers)
        # Record the output width on the torso so it can be shared with other models.
        torso.hidden_dim = hidden_dim
        return torso

    @staticmethod
    def _check_torso_input_dim(torso: nn.Module, input_dim: int) -> None:
        """Assert that the first submodule with ``in_features`` accepts ``input_dim``."""
        first = next((m for m in torso.modules() if hasattr(m, "in_features")), None)
        if first is not None:
            assert first.in_features == input_dim, "Input dim does not match torso input dim"

    @staticmethod
    def _torso_output_dim(torso: nn.Module) -> int:
        """Return the torso's output width: ``torso.hidden_dim`` or the last ``out_features``."""
        dim = getattr(torso, "hidden_dim", None)
        if dim is not None:
            return dim
        # Reversed pre-order traversal visits the last leaf of a Sequential first.
        for module in reversed(list(torso.modules())):
            if hasattr(module, "out_features"):
                return module.out_features
        raise AttributeError(
            "Cannot infer torso output dim; set `torso.hidden_dim` explicitly"
        )

    def get_heads_to_include(self):
        """Sample a non-empty subset of head indices.

        Each head is included independently when a uniform draw in [0, 1) is
        below ``self.delta``; the draw is repeated until at least one head is
        selected.

        Returns:
            Sorted list of selected head indices (Python ints).

        Raises:
            ValueError: If ``n_heads <= 0`` or ``delta <= 0``; no head could
                ever be selected, so sampling would never terminate.
        """
        if self.n_heads <= 0 or self.delta <= 0:
            raise ValueError(
                f"Cannot select heads with n_heads={self.n_heads} and delta={self.delta}"
            )
        while True:
            mask = torch.rand(self.n_heads) < self.delta
            heads_to_include = mask.nonzero().flatten().tolist()
            if heads_to_include:
                return heads_to_include

    def get_random_head(self):
        """Return one head index drawn uniformly from ``range(n_heads)``."""
        return torch.randint(0, self.n_heads, (1,)).item()

    def forward(self, x, head_index):
        """Apply the shared torso, then the head selected by ``head_index``."""
        x = self.torso(x)

        return self.heads[head_index](x)

        
